# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.math_ops.matrix_inverse."""

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
from ..utils.timer_wrapper import tensorflow_op_timer



def _AddTest(test_class, op_name, testcase_name, fn):
    test_name = "_".join(["test", op_name, testcase_name])
    if hasattr(test_class, test_name):
        raise RuntimeError("Test %s defined more than once" % test_name)
    setattr(test_class, test_name, fn)


class SvdOpTest(test.TestCase):
    """Forward-op tests for linalg_ops.svd, instrumented with tensorflow_op_timer.

    Each timed call constructs the SVD op inside the timer context and feeds
    the resulting tensor(s) into the timer's generator via `timer.gen.send`.
    NOTE(review): the exact tensorflow_op_timer protocol is defined in
    ..utils.timer_wrapper and is not visible here — the send() calls are
    presumably how the wrapper records the op being timed; confirm there.
    """

    # @test_util.run_in_graph_and_eager_modes(use_gpu=True)
    # def testWrongDimensions(self):
    #     # The input to svd should be a tensor of at least rank 2.
    #     scalar = constant_op.constant(1.)
    #     with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
    #                                 "rank.* 2.*0"):
    #         linalg_ops.svd(scalar)
    #     vector = constant_op.constant([1., 2.])
    #     with self.assertRaisesRegex((ValueError, errors_impl.InvalidArgumentError),
    #                                 "rank.* 2.*1"):
    #         linalg_ops.svd(vector)

    @test_util.run_in_graph_and_eager_modes(use_gpu=True)
    def testThrowDeterminismError(self):
        """SVD of a single-column matrix under deterministic ops on CUDA GPUs
        is expected to raise UnimplementedError."""
        shape = [6, 1]
        seed = [42, 24]
        matrix = stateless_random_ops.stateless_random_normal(shape, seed)
        with test_util.deterministic_ops():
            if test_util.is_gpu_available(cuda_only=True):
                with self.assertRaisesRegex(
                    errors_impl.UnimplementedError,
                    "Determinism is not yet supported for SVD of matrices with 1 column."
                ):
                    self.evaluate(linalg_ops.svd(matrix))

    @test_util.run_in_graph_and_eager_modes(use_gpu=True)
    def testDeterminism(self):
        """Repeated SVD of the same matrix under deterministic ops must yield
        bit-identical results on CUDA GPUs."""
        shape = [6, 5]
        seed = [42, 24]
        matrix = stateless_random_ops.stateless_random_normal(shape, seed)
        with test_util.deterministic_ops():
            if test_util.is_gpu_available(cuda_only=True):
                # Time only the op construction; the evaluated result below
                # comes from a separate (identical) svd call.
                timer = tensorflow_op_timer()
                with timer:
                    test = linalg_ops.svd(matrix)
                    timer.gen.send(test)
                s1, u1, v1 = self.evaluate(linalg_ops.svd(matrix))
                for _ in range(5):
                    timer = tensorflow_op_timer()
                    with timer:
                        test = linalg_ops.svd(matrix)
                        timer.gen.send(test)
                    s2, u2, v2 = self.evaluate(linalg_ops.svd(matrix))
                    self.assertAllEqual(s1, s2)
                    self.assertAllEqual(u1, u2)
                    self.assertAllEqual(v1, v2)

    @test_util.run_in_graph_and_eager_modes(use_gpu=True)
    def DISABLED_testBadInputs(self):
        """NaN/Inf inputs should propagate NaNs through all SVD outputs.

        Disabled: prefixed DISABLED_ so the test runner skips it.
        """
        # TODO(b/185822300): re-enable after the bug is fixed in CUDA-11.x
        # The input to svd should be a tensor of at least rank 2.
        for bad_val in [np.nan, np.inf]:
            matrix = np.array([[1, bad_val], [0, 1]])
            timer = tensorflow_op_timer()
            with timer:
                s, u, v = linalg_ops.svd(matrix, compute_uv=True)
                # NOTE(review): trailing comma makes this send(s), same as the
                # other call sites — not a tuple argument.
                timer.gen.send(s,)
            s, u, v = self.evaluate([s, u, v])
            for i in range(2):
                self.assertTrue(np.isnan(s[i]))
                for j in range(2):
                    self.assertTrue(np.isnan(u[i, j]))
                    self.assertTrue(np.isnan(v[i, j]))

    @test_util.run_in_graph_and_eager_modes(use_gpu=True)
    def testExecuteMultipleWithoutError(self):
        """Build SVD ops for every compute_uv/full_matrices combination on
        identical inputs and check pairwise-equal outputs after one evaluate."""
        all_ops = []
        shape = [6, 5]
        seed = [42, 24]
        for compute_uv_ in True, False:
            for full_matrices_ in True, False:
                # Same stateless seed -> matrix1 and matrix2 are identical.
                matrix1 = stateless_random_ops.stateless_random_normal(
                    shape, seed)
                matrix2 = stateless_random_ops.stateless_random_normal(
                    shape, seed)
                self.assertAllEqual(matrix1, matrix2)
                if compute_uv_:
                    timer = tensorflow_op_timer()
                    with timer:
                        s1, u1, v1 = linalg_ops.svd(
                        matrix1, compute_uv=compute_uv_, full_matrices=full_matrices_)
                        timer.gen.send(s1)
                    s2, u2, v2 = linalg_ops.svd(
                        matrix2, compute_uv=compute_uv_, full_matrices=full_matrices_)
                    # Interleave so val[i] and val[i + 1] are a comparable pair.
                    all_ops += [s1, s2, u1, u2, v1, v2]
                else:
                    timer = tensorflow_op_timer()
                    with timer:
                        s1 = linalg_ops.svd(
                        matrix1, compute_uv=compute_uv_, full_matrices=full_matrices_)
                        timer.gen.send(s1)
                    timer = tensorflow_op_timer()
                    with timer:
                        s2 = linalg_ops.svd(
                        matrix2, compute_uv=compute_uv_, full_matrices=full_matrices_)
                        timer.gen.send(s2)
                    all_ops += [s1, s2]
        val = self.evaluate(all_ops)
        # Outputs were appended in pairs; each adjacent pair must match.
        for i in range(0, len(val), 2):
            self.assertAllEqual(val[i], val[i + 1])

    @test_util.run_in_graph_and_eager_modes(use_gpu=True)
    def testEmptyBatches(self):
        """SVD of an empty batch (leading dim 0) returns empty outputs with
        the expected trailing shapes."""
        matrices = constant_op.constant(1.0, shape=[0, 2, 2])
        timer = tensorflow_op_timer()
        with timer: 
            test = linalg_ops.svd(matrices)
            timer.gen.send(test)
        s, u, v = self.evaluate(linalg_ops.svd(matrices))
        self.assertAllEqual(s, np.zeros([0, 2]))
        self.assertAllEqual(u, np.zeros([0, 2, 2]))
        self.assertAllEqual(v, np.zeros([0, 2, 2]))


def _GetSvdOpTest(dtype_, shape_, use_static_shape_, compute_uv_,
                  full_matrices_):
    """Build a forward-SVD test comparing TF's SVD against numpy's.

    Args:
      dtype_: numpy dtype of the random input matrix.
      shape_: full input shape, batch dims + (rows, cols).
      use_static_shape_: if True use a constant; else a placeholder fed at
        session-run time (skipped in eager mode).
      compute_uv_: whether to request singular vectors.
      full_matrices_: whether to request full-size u and v.

    Returns:
      A test function suitable for _AddTest.
    """

    def CompareSingularValues(self, x, y, tol):
        # Scale the tolerance by the largest singular value (x and y are
        # sorted descending, so index 0 is the largest); fall back to the raw
        # tolerance for empty inputs.
        atol = (x[0] + y[0]) * tol if len(x) else tol
        self.assertAllClose(x, y, atol=atol)

    def CompareSingularVectors(self, x, y, rank, tol):
        # We only compare the first 'rank' singular vectors since the
        # remainder form an arbitrary orthonormal basis for the
        # (row- or column-) null space, whose exact value depends on
        # implementation details. Notice that since we check that the
        # matrices of singular vectors are unitary elsewhere, we do
        # implicitly test that the trailing vectors of x and y span the
        # same space.
        x = x[..., 0:rank]
        y = y[..., 0:rank]
        # Singular vectors are only unique up to sign (complex phase factor for
        # complex matrices), so we normalize the sign first.
        sum_of_ratios = np.sum(np.divide(y, x), -2, keepdims=True)
        phases = np.divide(sum_of_ratios, np.abs(sum_of_ratios))
        x *= phases
        self.assertAllClose(x, y, atol=2 * tol)

    def CheckApproximation(self, a, u, s, v, full_matrices_, tol):
        # Tests that a ~= u*diag(s)*transpose(v).
        batch_shape = a.shape[:-2]
        m = a.shape[-2]
        n = a.shape[-1]
        diag_s = math_ops.cast(array_ops.matrix_diag(s), dtype=dtype_)
        if full_matrices_:
            # With full matrices, diag(s) must be padded with zeros to the
            # rectangular m x n shape before multiplication.
            if m > n:
                zeros = array_ops.zeros(batch_shape + (m - n, n), dtype=dtype_)
                diag_s = array_ops.concat([diag_s, zeros], a.ndim - 2)
            elif n > m:
                zeros = array_ops.zeros(batch_shape + (m, n - m), dtype=dtype_)
                diag_s = array_ops.concat([diag_s, zeros], a.ndim - 1)
        a_recon = math_ops.matmul(u, diag_s)
        a_recon = math_ops.matmul(a_recon, v, adjoint_b=True)
        self.assertAllClose(a_recon, a, rtol=tol, atol=tol)

    def CheckUnitary(self, x, tol):
        # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
        xx = math_ops.matmul(x, x, adjoint_a=True)
        identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
        self.assertAllClose(identity, xx, atol=tol)

    @test_util.run_in_graph_and_eager_modes(use_gpu=True)
    def Test(self):
        # Placeholders don't exist in eager mode; dynamic-shape variants are
        # only meaningful in graph mode.
        if not use_static_shape_ and context.executing_eagerly():
            return
        is_complex = dtype_ in (np.complex64, np.complex128)
        is_single = dtype_ in (np.float32, np.complex64)
        tol = 3e-4 if is_single else 1e-12
        if test.is_gpu_available():
            # The gpu version returns results that are much less accurate.
            tol *= 200
        np.random.seed(42)
        x_np = np.random.uniform(
            low=-1.0, high=1.0, size=np.prod(shape_)).reshape(shape_).astype(dtype_)
        if is_complex:
            x_np += 1j * np.random.uniform(
                low=-1.0, high=1.0,
                size=np.prod(shape_)).reshape(shape_).astype(dtype_)

        if use_static_shape_:
            x_tf = constant_op.constant(x_np)
        else:
            x_tf = array_ops.placeholder(dtype_)

        if compute_uv_:
            timer = tensorflow_op_timer()
            with timer: 
                s_tf, u_tf, v_tf = linalg_ops.svd(
                x_tf, compute_uv=compute_uv_, full_matrices=full_matrices_)
                timer.gen.send(s_tf)
            if use_static_shape_:
                s_tf_val, u_tf_val, v_tf_val = self.evaluate(
                    [s_tf, u_tf, v_tf])
            else:
                with self.session() as sess:
                    s_tf_val, u_tf_val, v_tf_val = sess.run([s_tf, u_tf, v_tf],
                                                            feed_dict={x_tf: x_np})
        else:
            timer = tensorflow_op_timer()
            with timer: 
                s_tf = linalg_ops.svd(
                x_tf, compute_uv=compute_uv_, full_matrices=full_matrices_)
                timer.gen.send(s_tf)
            if use_static_shape_:
                s_tf_val = self.evaluate(s_tf)
            else:
                with self.session() as sess:
                    s_tf_val = sess.run(s_tf, feed_dict={x_tf: x_np})

        if compute_uv_:
            u_np, s_np, v_np = np.linalg.svd(
                x_np, compute_uv=compute_uv_, full_matrices=full_matrices_)
        else:
            s_np = np.linalg.svd(
                x_np, compute_uv=compute_uv_, full_matrices=full_matrices_)
        # We explicitly avoid the situation where numpy eliminates a first
        # dimension that is equal to one.
        s_np = np.reshape(s_np, s_tf_val.shape)

        CompareSingularValues(self, s_np, s_tf_val, tol)
        if compute_uv_:
            # numpy returns v as v^H relative to TF's convention, hence the
            # conjugate-transpose before comparing.
            CompareSingularVectors(self, u_np, u_tf_val, min(shape_[-2:]), tol)
            CompareSingularVectors(self, np.conj(np.swapaxes(v_np, -2, -1)), v_tf_val,
                                   min(shape_[-2:]), tol)
            CheckApproximation(self, x_np, u_tf_val, s_tf_val, v_tf_val,
                               full_matrices_, tol)
            CheckUnitary(self, u_tf_val, tol)
            CheckUnitary(self, v_tf_val, tol)

    return Test


class SvdGradOpTest(test.TestCase):
    """Container for generated first-order SVD gradient tests."""
    pass  # Filled in below


def _NormalizingSvd(tf_a, full_matrices_):
    """Compute SVD with singular-vector phases normalized for gradient checks.

    Args:
      tf_a: the matrix (or batch of matrices) tensor to decompose.
      full_matrices_: forwarded to linalg_ops.svd's `full_matrices` argument.

    Returns:
      Tuple (s, u, v) with u and v rescaled so that the decomposition is
      unique (making finite-difference gradients well-defined).
    """
    timer = tensorflow_op_timer()
    with timer: 
        tf_s, tf_u, tf_v = linalg_ops.svd(
        tf_a, compute_uv=True, full_matrices=full_matrices_)
        timer.gen.send(tf_s)
    # Singular vectors are only unique up to an arbitrary phase. We normalize
    # the vectors such that the first component of u (if m >=n) or v (if n > m)
    # have phase 0.
    m = tf_a.shape[-2]
    n = tf_a.shape[-1]
    if m >= n:
        top_rows = tf_u[..., 0:1, :]
    else:
        top_rows = tf_v[..., 0:1, :]
    if tf_u.dtype.is_complex:
        # Unit-magnitude complex factor cancelling the top row's phase.
        angle = -math_ops.angle(top_rows)
        phase = math_ops.complex(math_ops.cos(angle), math_ops.sin(angle))
    else:
        phase = math_ops.sign(top_rows)
    # phase has min(m, n) entries; slice to match each factor's column count.
    tf_u *= phase[..., :m]
    tf_v *= phase[..., :n]
    return tf_s, tf_u, tf_v


def _GetSvdGradOpTest(dtype_, shape_, compute_uv_, full_matrices_):
    """Build a first-order gradient test for SVD via gradient_checker_v2.

    Args:
      dtype_: numpy dtype of the random input.
      shape_: full input shape, batch dims + (rows, cols).
      compute_uv_: whether gradients of u and v are checked in addition to s.
      full_matrices_: whether the SVD uses full-size u and v.

    Returns:
      A test function suitable for _AddTest.
    """

    @test_util.run_in_graph_and_eager_modes(use_gpu=True)
    def Test(self):

        def RandomInput():
            # Fixed seed keeps the input identical across the multiple calls
            # made by the gradient checker.
            np.random.seed(42)
            a = np.random.uniform(low=-1.0, high=1.0,
                                  size=shape_).astype(dtype_)
            if dtype_ in [np.complex64, np.complex128]:
                a += 1j * np.random.uniform(
                    low=-1.0, high=1.0, size=shape_).astype(dtype_)
            return a

        # Optimal stepsize for central difference is O(epsilon^{1/3}).
        # See Equation (21) in:
        # http://www.karenkopecky.net/Teaching/eco613614/Notes_NumericalDifferentiation.pdf
        # TODO(rmlarsen): Move step size control to gradient checker.
        epsilon = np.finfo(dtype_).eps
        delta = 0.25 * epsilon**(1.0 / 3.0)
        if dtype_ in [np.float32, np.complex64]:
            tol = 3e-2
        else:
            tol = 1e-6
        if compute_uv_:
            # Check gradients of each output separately; _NormalizingSvd pins
            # the arbitrary phase so finite differences are well-defined.
            funcs = [
                lambda a: _NormalizingSvd(a, full_matrices_)[0],
                lambda a: _NormalizingSvd(a, full_matrices_)[1],
                lambda a: _NormalizingSvd(a, full_matrices_)[2]
            ]
        else:
            funcs = [lambda a: linalg_ops.svd(a, compute_uv=False)]

        for f in funcs:
            theoretical, numerical = gradient_checker_v2.compute_gradient(
                f, [RandomInput()], delta=delta)
            self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

    return Test


class SvdGradGradOpTest(test.TestCase):
    """Container for generated second-order SVD gradient tests."""
    pass  # Filled in below


def _GetSvdGradGradOpTest(dtype_, shape_, compute_uv_, full_matrices_):
    """Build a second-order (gradient-of-gradient) test for SVD.

    Checks the gradient of sum(s [+ u + v]) w.r.t. the input using the v1
    gradient checker. Graph-mode only (run_v1_only).

    Args:
      dtype_: numpy dtype of the random input.
      shape_: full input shape, batch dims + (rows, cols).
      compute_uv_: whether u and v contribute to the differentiated output.
      full_matrices_: whether the SVD uses full-size u and v.

    Returns:
      A test function suitable for _AddTest.
    """

    @test_util.run_v1_only("b/120545219")
    def Test(self):
        np.random.seed(42)
        a = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_)
        if dtype_ in [np.complex64, np.complex128]:
            a += 1j * np.random.uniform(
                low=-1.0, high=1.0, size=shape_).astype(dtype_)
        # Optimal stepsize for central difference is O(epsilon^{1/3}).
        # See Equation (21) in:
        # http://www.karenkopecky.net/Teaching/eco613614/Notes_NumericalDifferentiation.pdf
        # TODO(rmlarsen): Move step size control to gradient checker.
        epsilon = np.finfo(dtype_).eps
        delta = 0.1 * epsilon**(1.0 / 3.0)
        tol = 1e-5
        with self.session():
            tf_a = constant_op.constant(a)
            if compute_uv_:
                tf_s, tf_u, tf_v = _NormalizingSvd(tf_a, full_matrices_)
                outputs = [tf_s, tf_u, tf_v]
            else:
                timer = tensorflow_op_timer()
                with timer: 
                    tf_s = linalg_ops.svd(tf_a, compute_uv=False)
                    timer.gen.send(tf_s)
                outputs = [tf_s]
            # Reduce all outputs to a scalar, take its first-order gradient,
            # then gradient-check that gradient (i.e. check second derivatives).
            outputs_sums = [math_ops.reduce_sum(o) for o in outputs]
            tf_func_outputs = math_ops.add_n(outputs_sums)
            grad = gradients_impl.gradients(tf_func_outputs, tf_a)[0]
            x_init = np.random.uniform(
                low=-1.0, high=1.0, size=shape_).astype(dtype_)
            if dtype_ in [np.complex64, np.complex128]:
                x_init += 1j * np.random.uniform(
                    low=-1.0, high=1.0, size=shape_).astype(dtype_)
            theoretical, numerical = gradient_checker.compute_gradient(
                tf_a,
                tf_a.get_shape().as_list(),
                grad,
                grad.get_shape().as_list(),
                x_init_value=x_init,
                delta=delta)
            self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

    return Test


class SVDBenchmark(test.Benchmark):
    """Benchmarks SVD over a range of square and batched shapes on CPU/GPU."""

    # (batch dims +) (rows, cols) shapes to benchmark.
    shapes = [
        (4, 4),
        (8, 8),
        (16, 16),
        (101, 101),
        (256, 256),
        (1024, 1024),
        (2048, 2048),
        (1, 8, 8),
        (10, 8, 8),
        (100, 8, 8),
        (1000, 8, 8),
        (1, 32, 32),
        (10, 32, 32),
        (100, 32, 32),
        (1000, 32, 32),
        (1, 256, 256),
        (10, 256, 256),
        (100, 256, 256),
    ]

    def benchmarkSVDOp(self):
        """Times a grouped (s, u, v) SVD op per shape on CPU, and on GPU when
        one is available."""
        for shape_ in self.shapes:
            with ops.Graph().as_default(), \
                    session.Session(config=benchmark.benchmark_config()) as sess, \
                    ops.device("/cpu:0"):
                matrix_value = np.random.uniform(
                    low=-1.0, high=1.0, size=shape_).astype(np.float32)
                matrix = variables.Variable(matrix_value)
                timer = tensorflow_op_timer()
                with timer: 
                    # NOTE(review): tf svd returns (s, u, v); the names here
                    # look swapped. Harmless since the outputs are only
                    # grouped below, but confirm before relying on the names.
                    u, s, v = linalg_ops.svd(matrix)
                    timer.gen.send(u)
                self.evaluate(variables.global_variables_initializer())
                self.run_op_benchmark(
                    sess,
                    control_flow_ops.group(u, s, v),
                    min_iters=25,
                    name="SVD_cpu_{shape}".format(shape=shape_))

            if test.is_gpu_available(True):
                with ops.Graph().as_default(), \
                        session.Session(config=benchmark.benchmark_config()) as sess, \
                        ops.device("/device:GPU:0"):
                    matrix_value = np.random.uniform(
                        low=-1.0, high=1.0, size=shape_).astype(np.float32)
                    matrix = variables.Variable(matrix_value)
                    timer = tensorflow_op_timer()
                    with timer: 
                        u, s, v = linalg_ops.svd(matrix)
                        timer.gen.send(u)
                    self.evaluate(variables.global_variables_initializer())
                    self.run_op_benchmark(
                        sess,
                        control_flow_ops.group(u, s, v),
                        min_iters=25,
                        name="SVD_gpu_{shape}".format(shape=shape_))


if __name__ == "__main__":
    # Generate forward-op tests over the cross product of dtype, shape,
    # static/dynamic shape, compute_uv and full_matrices, attaching each to
    # SvdOpTest under a unique name.
    dtypes_to_test = [np.float32, np.float64, np.complex64, np.complex128]
    for compute_uv in False, True:
        for full_matrices in False, True:
            for dtype in dtypes_to_test:
                for rows in 0, 1, 2, 5, 10, 32, 100:
                    for cols in 0, 1, 2, 5, 10, 32, 100:
                        # Rank-2 batch dims only for small matrices (the list
                        # multiplication adds (3, 2) when max(rows, cols) < 10).
                        for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
                            full_shape = batch_dims + (rows, cols)
                            for use_static_shape in set([True, False]):
                                name = "%s_%s_static_shape_%s__compute_uv_%s_full_%s" % (
                                    dtype.__name__, "_".join(
                                        map(str, full_shape)),
                                    use_static_shape, compute_uv, full_matrices)
                                _AddTest(
                                    SvdOpTest, "Svd", name,
                                    _GetSvdOpTest(dtype, full_shape, use_static_shape,
                                                  compute_uv, full_matrices))
    # Generate first- and second-order gradient tests. Complex dtypes are only
    # included when compute_uv is False (the list multiplication drops them
    # otherwise).
    for compute_uv in False, True:
        for full_matrices in False, True:
            dtypes = ([np.float32, np.float64] + [np.complex64, np.complex128] *
                      (not compute_uv))
            for dtype in dtypes:
                mat_shapes = [(10, 11), (11, 10), (11, 11), (2, 2, 2, 3)]
                if not full_matrices or not compute_uv:
                    mat_shapes += [(5, 11), (11, 5)]
                for mat_shape in mat_shapes:
                    for batch_dims in [(), (3,)]:
                        full_shape = batch_dims + mat_shape
                        name = "%s_%s_compute_uv_%s_full_%s" % (dtype.__name__, "_".join(
                            map(str, full_shape)), compute_uv, full_matrices)
                        _AddTest(
                            SvdGradOpTest, "SvdGrad", name,
                            _GetSvdGradOpTest(dtype, full_shape, compute_uv, full_matrices))
                        # The results are too inaccurate for float32.
                        if dtype in (np.float64, np.complex128):
                            _AddTest(
                                SvdGradGradOpTest, "SvdGradGrad", name,
                                _GetSvdGradGradOpTest(dtype, full_shape, compute_uv,
                                                      full_matrices))
    test.main()
